from allennlp.data.tokenizers import SpacyTokenizer as AllenWordTokenizer
from allennlp.data.tokenizers import CharacterTokenizer as AllenCharacterTokenizer
from allennlp.data.tokenizers import PretrainedTransformerTokenizer as AllenPretrainedTransformerTokenizer
from allennlp.data.tokenizers import WhitespaceTokenizer as AllenWhitespaceTokenizer
from allennlp.data import Token
from spacy.lang.en.stop_words import STOP_WORDS

from transformers import BertTokenizerFast, RobertaTokenizerFast, AlbertTokenizer


class WordTokenizer(AllenWordTokenizer):
    """spaCy-backed AllenNLP word tokenizer with a ``detokenize`` inverse."""

    # Marker string for unknown words; how callers use it is not visible here.
    unk = '@@UNKNOWN@@'

    def detokenize(self, tokens):
        """Rebuild a sentence from word strings.

        Args:
            tokens: iterable of ``str`` pieces.

        Returns:
            The pieces joined with a single space between each pair.
        """
        words = list(tokens)
        return ' '.join(words)


class CharacterTokenizer(AllenCharacterTokenizer):
    """AllenNLP character tokenizer that can also glue characters back together."""

    # Marker string for unknown characters (a single space).
    unk = ' '

    def detokenize(self, tokens):
        """Concatenate character strings with no separator.

        Args:
            tokens: iterable of ``str`` characters.

        Returns:
            The characters joined into one string.
        """
        separator = ''
        return separator.join(list(tokens))


class PretrainedTransformerTokenizer():
    """Wrap a Hugging Face tokenizer behind this module's tokenize/detokenize API.

    Supported ``model_name`` values are ``'bert-base-uncased'``,
    ``'roberta-base'`` and ``'albert-base-v2'``.
    """

    # Class-level default unknown marker (BERT's). ``__init__`` overrides it per
    # instance with the loaded tokenizer's own ``unk_token``. RoBERTa and ALBERT
    # do not use ``[UNK]``, so the hard-coded value would be split into pieces.
    unk = '[UNK]'

    def __init__(self, model_name):
        """Load the pretrained tokenizer for ``model_name``.

        Raises:
            ValueError: if ``model_name`` is not one of the supported names.
        """
        if model_name == 'bert-base-uncased':
            self.tokenizer = BertTokenizerFast.from_pretrained(model_name)
        elif model_name == 'roberta-base':
            self.tokenizer = RobertaTokenizerFast.from_pretrained(model_name)
        elif model_name == 'albert-base-v2':
            self.tokenizer = AlbertTokenizer.from_pretrained(model_name)
        else:
            raise ValueError(f'{model_name} tokenizer not supported!')
        self.unk = self.tokenizer.unk_token

    def tokenize(self, sentence):
        """Tokenize ``sentence`` into AllenNLP ``Token`` objects.

        Returns:
            list of ``Token`` with ``text`` set to the subword string and
            ``text_id`` to its vocabulary id. The special tokens added by the
            model are included; ``detokenize`` expects one at each end.
        """
        # Encode once, then derive the strings from the ids so both sequences
        # are aligned by construction. ``tokenize(..., add_special_tokens=True)``
        # only adds special tokens on *fast* tokenizers. The slow AlbertTokenizer
        # ignores the flag while ``encode`` still adds them. That shifted every
        # text/id pair by one and made ``zip`` drop the last piece.
        text_ids = self.tokenizer.encode(sentence, add_special_tokens=True)
        texts = self.tokenizer.convert_ids_to_tokens(text_ids)
        return [Token(text=text, text_id=text_id)
                for text, text_id in zip(texts, text_ids)]

    def detokenize(self, tokens):
        """Rebuild a sentence from token strings produced by ``tokenize``.

        The first and last entries (the special tokens) are dropped.

        - BERT: wordpieces are space-joined and ``##`` continuations merged,
          exactly as before.
        - Other models: the tokenizer's own ``convert_tokens_to_string`` is
          used, so their subword markers are decoded rather than copied into
          the output.

        Args:
            tokens: iterable of token ``str`` values.

        Returns:
            The reconstructed sentence (``''`` when fewer than three tokens).
        """
        inner = list(tokens)[1:-1]
        if isinstance(self.tokenizer, BertTokenizerFast):
            return ' '.join(inner).replace(' ##', '')
        return self.tokenizer.convert_tokens_to_string(inner)


class StopwordFilter(WordTokenizer):
    """Word tokenizer that lowercases words and drops stop words."""

    # Stop-word set consulted by ``tokenize``. Defaults to spaCy's English list
    # and can now be overridden on a subclass or instance; previously
    # ``tokenize`` read the module global, so an override was ignored.
    STOP_WORDS = STOP_WORDS

    def tokenize(self, text):
        """Tokenize ``text`` and keep the lowercased words not in ``STOP_WORDS``.

        Note: this returns plain ``str`` values rather than the token objects
        the parent returns, so the result can be fed to ``detokenize``.

        Args:
            text: input string.

        Returns:
            list of lowercased, non-stop-word strings in original order.
        """
        stop_words = self.STOP_WORDS
        tokens = super().tokenize(text)
        lowered = (t.text.lower() for t in tokens)
        return [w for w in lowered if w not in stop_words]

    def filter(self, text):
        """Return ``text`` lowercased, stop words removed, re-joined by spaces."""
        return self.detokenize(self.tokenize(text))


class WhitespaceTokenizer(AllenWhitespaceTokenizer):
    """AllenNLP whitespace tokenizer paired with a space-joining ``detokenize``."""

    # Marker string for unknown words; how callers use it is not visible here.
    unk = '@@UNKNOWN@@'

    def detokenize(self, tokens):
        """Join token strings with single spaces.

        Args:
            tokens: iterable of ``str`` pieces.

        Returns:
            One string with the pieces separated by ``' '``.
        """
        pieces = tuple(tokens)
        glue = ' '
        return glue.join(pieces)
